import transformers
import logging
from colossalai.nn.lr_scheduler import LinearWarmupLR
from transformers import get_linear_schedule_with_warmup
from transformers import BertForPreTraining, RobertaForMaskedLM, RobertaConfig
from transformers import GPT2Config, GPT2LMHeadModel
from transformers import AutoTokenizer, AutoModelForMaskedLM
from colossalai.nn.optimizer import FusedAdam
from torch.optim import AdamW
from colossalai.core import global_context as gpc
import torch
import os
import sys
sys.path.append(os.getcwd())
from model.deberta_v2 import DebertaV2ForMaskedLM
from model.bert import BertForMaskedLM
import torch.nn as nn

from collections import OrderedDict

# Names exported by ``from <this module> import *``.
# NOTE(review): 'get_dataloader_for_pretraining' is neither defined nor imported
# in the visible part of this file; if it is really missing, a star-import of
# this module raises AttributeError -- confirm and drop or define it.
# LMModel, get_new_state_dict and save_ckpt are defined here but not exported.
__all__ = ['get_model', 'get_optimizer', 'get_lr_scheduler', 'get_dataloader_for_pretraining']


def get_new_state_dict(state_dict, start_index=13):
    """Return a copy of ``state_dict`` with a fixed-length prefix cut off every key.

    Args:
        state_dict: Mapping from parameter name to value.
        start_index: Number of leading characters dropped from each key
            (default 13).

    Returns:
        An ``OrderedDict`` holding the same values under the shortened keys,
        in the iteration order of ``state_dict``. If two keys shorten to the
        same name, the later value replaces the earlier one.
    """
    trimmed_items = ((key[start_index:], value) for key, value in state_dict.items())
    return OrderedDict(trimmed_items)


class LMModel(nn.Module):
    """Thin ``nn.Module`` wrapper that forwards its inputs to a wrapped model.

    Attributes:
        checkpoint: Value of ``args.checkpoint_activations``.
        config: The configuration object handed to the constructor.
        model: The wrapped model.
    """

    def __init__(self, model, config, args):
        """Store the model and config; optionally turn on gradient checkpointing.

        Args:
            model: Model to wrap. It must provide
                ``gradient_checkpointing_enable()`` when
                ``args.checkpoint_activations`` is truthy.
            config: Configuration object kept on the wrapper as ``config``.
            args: Object exposing a ``checkpoint_activations`` attribute.
        """
        super().__init__()

        use_checkpointing = args.checkpoint_activations
        self.checkpoint = use_checkpointing
        self.config = config
        self.model = model

        if use_checkpointing:
            model.gradient_checkpointing_enable()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        """Call the wrapped model with keyword arguments and return its output.

        The wrapped model's return value is passed through unchanged.
        """
        model_inputs = {
            'input_ids': input_ids,
            'token_type_ids': token_type_ids,
            'attention_mask': attention_mask,
        }
        return self.model(**model_inputs)


def get_model(args, logger):
    """Build the masked-LM model selected by ``args`` and optionally load weights.

    Args:
        args: Object exposing
            ``mlm`` -- ``'bert'`` or ``'deberta_v2'``;
            ``bert_config`` -- path to the JSON model config (used for both
            model kinds);
            ``load_pretrain_model`` -- path to a state-dict file, or an
            empty value (``''`` / ``None``) to skip loading;
            ``checkpoint_activations`` -- truthy to enable gradient
            checkpointing on the model.
        logger: Logger whose ``info`` method is called after weights are loaded.

    Returns:
        A ``(config, model, numel)`` tuple, where ``numel`` is the total
        number of elements over ``model.parameters()``.

    Raises:
        ValueError: If ``args.mlm`` is not a supported model kind.
        FileNotFoundError: If ``args.load_pretrain_model`` is set but the
            file does not exist.
    """
    # Validate inputs up front so a bad argument fails before the (potentially
    # large) model is constructed. Explicit raises are used instead of
    # ``assert`` because assertions are stripped under ``python -O``.
    if args.mlm not in ('bert', 'deberta_v2'):
        raise ValueError(f"Invalid mlm! Expected 'bert' or 'deberta_v2', got {args.mlm!r}")

    pretrained_path = args.load_pretrain_model
    if pretrained_path and not os.path.exists(pretrained_path):
        raise FileNotFoundError(f"Pretrained checkpoint not found: {pretrained_path}")

    if args.mlm == 'bert':
        config = transformers.BertConfig.from_json_file(args.bert_config)
        model = BertForMaskedLM(config)
    else:
        config = transformers.DebertaV2Config.from_json_file(args.bert_config)
        model = DebertaV2ForMaskedLM(config)

    if pretrained_path:
        # Map tensors to this process's current CUDA device; fall back to CPU
        # when CUDA is unavailable so loading does not crash on CPU-only hosts.
        if torch.cuda.is_available():
            load_device = torch.device(f"cuda:{torch.cuda.current_device()}")
        else:
            load_device = torch.device("cpu")
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        m_state_dict = torch.load(pretrained_path, map_location=load_device)
        # strict=True requires the checkpoint keys to match the model exactly.
        # Original author's note: every process must end up with identical
        # parameters.
        model.load_state_dict(m_state_dict, strict=True)
        logger.info("load model success")

    numel = sum(p.numel() for p in model.parameters())
    if args.checkpoint_activations:
        model.gradient_checkpointing_enable()

    return config, model, numel


def get_optimizer(model, lr, weight_decay=0.1, betas=(0.9, 0.95)):
    """Create a ``FusedAdam`` optimizer with two parameter groups.

    Parameters whose name contains any of ``'bias'``, ``'gamma'``, ``'beta'``
    or ``'LayerNorm'`` go into a group with weight decay 0.0; all other
    parameters go into a group that uses ``weight_decay``.

    Args:
        model: Module providing ``named_parameters()``.
        lr: Learning rate passed to the optimizer.
        weight_decay: Decay applied to the first (decayed) group. Defaults to
            0.1, the value that was previously hard-coded.
        betas: Adam beta coefficients. Defaults to ``(0.9, 0.95)``, the values
            that were previously hard-coded.

    Returns:
        The constructed ``FusedAdam`` optimizer.
    """
    no_decay = ('bias', 'gamma', 'beta', 'LayerNorm')

    # Single pass over the parameters; each list keeps named_parameters() order.
    decay_params = []
    no_decay_params = []
    for name, param in model.named_parameters():
        if any(marker in name for marker in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    optimizer_grouped_parameters = [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ]
    # list(...) keeps the optimizer receiving a list, as before, while letting
    # the default argument be an immutable tuple.
    return FusedAdam(optimizer_grouped_parameters, lr=lr, betas=list(betas))


def get_lr_scheduler(optimizer, total_steps, warmup_steps=2000, last_epoch=-1):
    """Build the LR schedule via ``get_linear_schedule_with_warmup``.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        total_steps: Passed as ``num_training_steps``.
        warmup_steps: Passed as ``num_warmup_steps`` (default 2000).
        last_epoch: Passed through as ``last_epoch`` (default -1).

    Returns:
        The scheduler object returned by ``get_linear_schedule_with_warmup``.
    """
    schedule_kwargs = dict(
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps,
        last_epoch=last_epoch,
    )
    return get_linear_schedule_with_warmup(optimizer, **schedule_kwargs)


def save_ckpt(model, optimizer, lr_scheduler, path, epoch, shard, global_step):
    """Write model weights and training state to two files derived from ``path``.

    Global rank 0 writes:
      * ``path + '_pytorch_model.bin'`` -- ``model.state_dict()``
      * ``path + '.op_lrs'`` -- a dict with the optimizer and LR-scheduler
        state dicts plus ``epoch``, ``shard`` and ``global_step``.

    All state dicts are gathered on every process; only the ``torch.save``
    calls are restricted to rank 0.

    Args:
        model: Model whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        lr_scheduler: Scheduler whose ``state_dict()`` is saved.
        path: Common prefix of the two output files.
        epoch: Stored under ``'epoch'``.
        shard: Stored under ``'shard'``.
        global_step: Stored under ``'global_step'``.
    """
    weights_file = path + '_pytorch_model.bin'
    training_state_file = path + '.op_lrs'

    training_state = {
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'epoch': epoch,
        'shard': shard,
        'global_step': global_step,
    }
    # Per the original author's note, every process must call
    # model.state_dict(), so this stays outside the rank check.
    weights = model.state_dict()

    if gpc.get_global_rank() != 0:
        return

    torch.save(training_state, training_state_file)
    torch.save(weights, weights_file)



